import numpy as np
import copy

class HindsightFilter:
    """Abstract base for hindsight filters.

    A filter is called with a trajectory and decides whether it should be
    relabelled. Subclasses override ``__call__``; the base only sets
    ``target_idx`` to -1.
    """

    def __init__(self):
        self.target_idx = -1

    def __call__(self, trajectory):
        # Concrete filters must provide their own evaluation.
        raise NotImplementedError()

IGNORE_FIRST = 3 # number of leading trajectory states to skip: the simulator can produce unreliable states in the first few steps

class NonPassiveGraphHindsightFilter(HindsightFilter):
    """Filter that fires when a trajectory's graph departs from a passive graph.

    For each timestep, ``passive_graph`` is subtracted from the graph.
    - When ``target_graph_idx >= 0``, the difference is summed over index
      ``target_graph_idx`` of axis 1.
    - Otherwise it is summed over the whole graph.

    A positive total is clamped to 1. The trajectory succeeds when the summed
    totals reach ``min_non_passive``.
    """

    def __init__(self, passive_graph,  min_non_passive, target_idx, target_graph_idx):
        self.passive_graph = passive_graph # if target_idx >= 0, this should be the passive graph of the target object
        self.min_non_passive = min_non_passive
        self.target_idx = target_idx
        self.target_graph_idx = target_graph_idx

    def _graph_comparison(self, graph):
        """Compare a stack of per-timestep graphs against ``passive_graph``.

        Assumes ``graph`` is (T, N, M) and ``passive_graph`` broadcasts against
        each timestep (TODO confirm).

        Returns:
            success: whether the summed per-timestep totals reach
                ``min_non_passive``.
            graph_totals: 1-D array of per-timestep totals, with positive
                values clamped to 1.
            reached_indexes: timestep indexes with a nonzero total (ndarray),
                or an empty list if there are none.
        """
        graph_diffs = graph - self.passive_graph
        if self.target_graph_idx >= 0:
            # Only the target's slice of the graph is considered.
            per_graphs = graph_diffs[:, self.target_graph_idx]
        else:
            per_graphs = np.sum(graph_diffs, axis=-2)
        graph_totals = np.sum(per_graphs, axis=-1)
        graph_totals[graph_totals > 0] = 1
        # NOTE(review): negative totals (graph missing a passive edge) stay
        # negative. They are reported by np.nonzero as reached indexes, yet they
        # reduce the count below. Confirm whether they should be clipped to 0.
        nonzero_idxes = np.nonzero(graph_totals)[0]
        reached_indexes = nonzero_idxes if len(nonzero_idxes) > 0 else list()
        success = np.sum(graph_totals.astype(int)) >= self.min_non_passive
        return success, graph_totals, reached_indexes

    def __call__(self, trajectory):
        """Evaluate ``trajectory.graph`` and ``trajectory.true_graph``.

        Returns ``(success, reached_indexes, count, true_success,
        true_reached_indexes, true_count)``. This is the same ordering used by
        ControlHindsightFilter and ActionGraphHindsightFilter.
        """
        success, graph_totals, reached_indexes = self._graph_comparison(trajectory.graph)
        true_success, true_graph_totals, true_reached_idxes = self._graph_comparison(trajectory.true_graph)
        return success, reached_indexes, np.sum(graph_totals), true_success, true_reached_idxes, np.sum(true_graph_totals)

class ControlHindsightFilter(HindsightFilter):
    """Filter that counts timesteps where the target interacts with a control index.

    A timestep (from ``IGNORE_FIRST`` on) counts as an interaction when
    ``graph[t, target_graph_idx, ctrl_idx]`` is nonzero.

    Optional extra checks, applied only to the predicted-graph success:
    - ``dist_test`` (when nonzero) requires the norm of
      ``trajectory.target_diff[IGNORE_FIRST:, target_graph_idx]`` to exceed it.
    - ``vel_test`` (when nonzero) requires the minimum of component 2 of that
      slice to fall below it.
    """

    def __init__(self, passive_graph,  min_non_passive, target_idx, target_graph_idx, ctrl_idx, dist_test=0.0, vel_test=0.0):
        self.passive_graph = passive_graph # if target_idx >= 0, this should be the passive graph of the target object
        self.min_non_passive = min_non_passive
        self.target_idx = target_idx
        self.target_graph_idx = target_graph_idx
        self.ctrl_idx = ctrl_idx
        self.dist_test = dist_test
        self.vel_test = vel_test

    def _graph_comparison(self, graph):
        """Return ``(reached, reached_indexes, count)`` for one graph stack.

        ``reached_indexes`` are absolute timestep indexes (offset by
        ``IGNORE_FIRST``).
        """
        window = graph[IGNORE_FIRST:, self.target_graph_idx, self.ctrl_idx]
        hit_steps = np.nonzero(window)[0] + IGNORE_FIRST
        hit_count = np.sum(window.astype(int))
        return hit_count >= self.min_non_passive, hit_steps, hit_count

    def __call__(self, trajectory):
        success, idxes, graph_count = self._graph_comparison(trajectory.graph)
        true_success, true_idxes, true_graph_count = self._graph_comparison(trajectory.true_graph)
        # target_diff is only read when a test is enabled and success still holds.
        # NOTE(review): the dist test takes the norm over the whole
        # (time x feature) slice. The vel test hard-codes feature index 2.
        if self.dist_test != 0 and success:
            success = np.linalg.norm(trajectory.target_diff[IGNORE_FIRST:, self.target_graph_idx]) > self.dist_test
        if self.vel_test != 0 and success:
            success = np.min(trajectory.target_diff[IGNORE_FIRST:, self.target_graph_idx, 2]) < self.vel_test
        return success, idxes, graph_count, true_success, true_idxes, true_graph_count

class ActionGraphHindsightFilter(HindsightFilter):
    """Filter that propagates reachability outward from the action column.

    The reachable set starts as ``{action_index}``. At each timestep from
    ``IGNORE_FIRST`` on, it gains ``row + 1`` for every row of the graph that
    has a nonzero entry in any reachable column.

    Once the target column appears in the set, the trajectory is marked
    reached. That timestep and every later one are recorded.
    """

    def __init__(self, action_index, target_index, target_graph_idx, passive_graph, min_non_passive):
        self.action_index = action_index
        self.target_graph_idx = target_graph_idx
        self.target_idx = target_index
        self.passive_graph = passive_graph # stored but not used by _graph_comparison
        self.min_non_passive = min_non_passive

    def _graph_comparison(self, graph):
        """Return ``(success, total, reached_indexes)`` for one graph stack."""
        # A negative target index is mapped to target + graph.shape[1] + 1.
        # NOTE(review): non-negative indices are used unchanged, without the +1
        # applied to the reachable rows below. Confirm which convention callers use.
        if self.target_graph_idx < 0:
            target_col = self.target_graph_idx + graph.shape[1] + 1
        else:
            target_col = self.target_graph_idx

        frontier = np.array([self.action_index])
        reached = False
        total = 0
        reached_steps = []
        for step, g in enumerate(graph[IGNORE_FIRST:], start=IGNORE_FIRST):
            # Rows touching any reachable column, shifted by +1 (presumably
            # converting row indices to column indices).
            new_cols = np.nonzero(g[:, frontier])[0] + 1
            frontier = np.array(np.unique(frontier.tolist() + new_cols.tolist()))
            if target_col in frontier:
                reached = True
                # NOTE(review): this accumulates the index *values* (with
                # repeats), not a count of interactions. Confirm this is intended.
                total += np.sum(new_cols)
            if reached and len(frontier) > 0:
                reached_steps.append(step)
        return total >= self.min_non_passive, total, reached_steps

    def __call__(self, trajectory):
        """Return ``(success, idxes, count, true_success, true_idxes, true_count)``."""
        success, graph_totals, reached_indexes = self._graph_comparison(trajectory.graph)
        true_success, true_graph_totals, true_reached_idxes = self._graph_comparison(trajectory.true_graph)
        return success, reached_indexes, np.sum(graph_totals), true_success, true_reached_idxes, np.sum(true_graph_totals)

class AllGraphHindsightFilter(HindsightFilter):
    """Trivial filter that accepts every trajectory.

    For both the predicted and the true graph it reports success, reached
    index ``[0]`` and ``len(trajectory)`` as the count.
    """

    def __init__(self, passive_graph,  min_non_passive, target_idx, target_graph_idx):
        # Stored for constructor parity with the other filters; __call__ ignores them.
        self.passive_graph = passive_graph # if target_idx >= 0, this should be the passive graph of the target object
        self.min_non_passive = min_non_passive
        self.target_idx = target_idx
        self.target_graph_idx = target_graph_idx

    def __call__(self, trajectory):
        length = len(trajectory)
        return True, [0], length, True, [0], length
